"""
Variants of models for tests.

"""
from tqdm import tqdm

import torch
from torch import nn
import torch.nn.functional as F
import dgl.nn as dglnn
from dgl.dataloading import DataLoader, NeighborSampler, MultiLayerFullNeighborSampler


class SAGE(nn.Module):
    """GraphSAGE model built from DGL ``SAGEConv`` layers using the "gcn" aggregator.

    Every layer except the last is followed by a normalisation module
    (``BatchNorm1d`` or ``Identity``), ReLU and dropout.

    Args:
        in_size: Input feature size of the first layer.
        hid_size: Output size of every non-final layer.
        out_size: Output size of the final layer.
        GNN_layer: Number of ``SAGEConv`` layers; must be at least 1.
        dropout: Probability passed to ``nn.Dropout``.
        is_batch_norm: If true, hidden layers use ``BatchNorm1d``; otherwise
            ``Identity`` is used in its place.
        bias: Forwarded to every ``SAGEConv``.

    Raises:
        ValueError: If ``GNN_layer`` is smaller than 1.
    """
    def __init__(self, in_size, hid_size, out_size,
                 GNN_layer=3, dropout=0.3,
                 is_batch_norm=True,
                 bias=True):
        super().__init__()
        if GNN_layer < 1:
            raise ValueError(f"Minimum GNN layer should be 1, current is {GNN_layer}")

        AGG = "gcn"
        self.layers = nn.ModuleList()
        self.bn = nn.ModuleList()
        # Feature widths at every layer boundary: in -> hid -> ... -> hid -> out.
        dims = [in_size] + [hid_size] * (GNN_layer - 1) + [out_size]
        for idx, (d_in, d_out) in enumerate(zip(dims[:-1], dims[1:])):
            self.layers.append(dglnn.SAGEConv(d_in, d_out, AGG, bias=bias))
            # Only hidden layers get a normalisation module, so
            # len(self.bn) == len(self.layers) - 1.
            if idx < GNN_layer - 1:
                if is_batch_norm:
                    self.bn.append(torch.nn.BatchNorm1d(d_out))
                else:
                    self.bn.append(torch.nn.Identity())

        self.dropout = nn.Dropout(dropout)
        self.in_size = in_size
        self.hid_size = hid_size
        self.out_size = out_size
        self.GNN_layer = GNN_layer
        self.is_batch_norm = is_batch_norm
        self.bias = bias

    def _hidden_post(self, l, h):
        """Apply normalisation, ReLU and dropout for hidden layer ``l``."""
        h = self.bn[l](h)
        h = F.relu(h)
        return self.dropout(h)

    def _layer_width(self, l):
        """Return the output width of layer ``l``."""
        return self.out_size if l == len(self.layers) - 1 else self.hid_size

    @staticmethod
    def _host_buffer(num_rows, width):
        """Allocate an uninitialised CPU buffer of shape ``(num_rows, width)``.

        Pinned memory is requested only when CUDA is available: pinning only
        benefits host<->GPU copies.
        NOTE(review): the original requested it unconditionally, which is
        believed to fail on CPU-only hosts - confirm.
        """
        return torch.empty(num_rows, width, device='cpu',
                           pin_memory=torch.cuda.is_available())

    @staticmethod
    def _all_nodes_loader(g, sampler, device, batch_size):
        """Build an unshuffled loader over node IDs ``0 .. g.num_nodes() - 1``."""
        return DataLoader(g, torch.arange(g.num_nodes()).to(g.device), sampler, device=device,
                          batch_size=batch_size, shuffle=False, drop_last=False,
                          num_workers=0)

    def forward(self, blocks, x):
        """Run the layers over their paired blocks and return the final output.

        Args:
            blocks: Iterable of graph blocks, paired with ``self.layers`` via
                ``zip`` (so surplus layers or blocks are silently ignored).
            x: Input node features fed to the first layer.

        Returns:
            Output of the last layer that was run. Normalisation, ReLU and
            dropout are applied after every layer except the model's final one.
        """
        h = x
        last = len(self.layers) - 1
        for l, (layer, block) in enumerate(zip(self.layers, blocks)):
            h = layer(block, h)
            if l != last:
                h = self._hidden_post(l, h)
        return h

    def inference(self, g, device,
                  batch_size, is_sample=False,
                  sample_size=2, return_embed=False,
                  feature_key='feat'):
        """Conduct layer-wise inference to get all the node embeddings.

        Each layer is evaluated for every node (in mini-batches) before the
        next layer starts; intermediate results are buffered on the CPU.
        Dropout and normalisation follow the module's current train/eval mode.

        Args:
            g: Graph whose ``ndata[feature_key]`` holds the input features.
            device: Device the computation runs on and the result is moved to.
            batch_size: Number of output nodes per mini-batch.
            is_sample: If true, sample ``sample_size`` neighbours per node with
                ``NeighborSampler``; otherwise use the full neighbourhood.
            sample_size: Fan-out used when ``is_sample`` is true.
            return_embed: If true, skip the final layer and return the
                activated output of the penultimate layer instead.
            feature_key: Key of the input features in ``g.ndata``.

        Returns:
            Tensor on ``device`` with ``g.num_nodes()`` rows and ``out_size``
            columns (``hid_size`` columns when ``return_embed`` is true).

        Raises:
            ValueError: If ``return_embed`` is requested on a single-layer
                model, which has no hidden embedding to return.
        """
        num_layers = len(self.layers)
        if return_embed and num_layers < 2:
            # Previously this fell through to an UnboundLocalError on ``y``.
            raise ValueError(
                "return_embed requires at least 2 GNN layers, "
                f"current is {num_layers}")

        feat = g.ndata[feature_key]
        if is_sample:
            sampler = NeighborSampler([sample_size], prefetch_node_feats=[feature_key])
        else:
            sampler = MultiLayerFullNeighborSampler(1, prefetch_node_feats=[feature_key])
        dataloader = self._all_nodes_loader(g, sampler, device, batch_size)

        # When embeddings are requested, stop before the final (prediction) layer.
        layers_to_run = num_layers - 1 if return_embed else num_layers
        for l in range(layers_to_run):
            layer = self.layers[l]
            y = self._host_buffer(g.num_nodes(), self._layer_width(l))
            feat = feat.to(device)
            for input_nodes, output_nodes, blocks in dataloader:
                h = layer(blocks[0], feat[input_nodes])  # len(blocks) = 1
                if l != num_layers - 1:
                    h = self._hidden_post(l, h)
                # by design, our output nodes are contiguous
                # (the loader walks arange(num_nodes) without shuffling)
                y[output_nodes[0]:output_nodes[-1]+1] = h.to('cpu')
            # The output of this layer is the input of the next one.
            feat = y
        return y.to(device)

    def inference_teacher(self, g, device, batch_size):
        """Conduct layer-wise inference to get all the node embeddings as teacher for mlp students.

        Uses full-neighbourhood sampling over ``g.ndata['feat']``. Dropout and
        normalisation follow the module's current train/eval mode.

        Args:
            g: Graph whose ``ndata['feat']`` holds the input features.
            device: Device the per-batch computation runs on.
            batch_size: Number of output nodes per mini-batch.

        Returns:
            Tuple ``(hid_embed, y)``: ``hid_embed`` is a CPU tensor of shape
            ``(g.num_nodes(), hid_size)`` holding the penultimate layer's raw
            output (before normalisation, ReLU and dropout); ``y`` is the final
            layer's output after ``log_softmax`` over dim 1, on the CPU.

        Raises:
            ValueError: If the model has a single layer and therefore no
                hidden embedding.
        """
        num_layers = len(self.layers)
        if num_layers < 2:
            # Fail before the expensive pass instead of with an
            # UnboundLocalError on ``hid_embed`` at the very end.
            raise ValueError(
                "inference_teacher requires at least 2 GNN layers, "
                f"current is {num_layers}")

        feat = g.ndata['feat']
        sampler = MultiLayerFullNeighborSampler(1, prefetch_node_feats=['feat'])
        dataloader = self._all_nodes_loader(g, sampler, device, batch_size)

        penultimate = num_layers - 2
        for l, layer in enumerate(self.layers):
            print(f"Inferencing layer {l}...")
            y = self._host_buffer(g.num_nodes(), self._layer_width(l))
            if l == penultimate:
                hid_embed = self._host_buffer(g.num_nodes(), self.hid_size)
            feat = feat.to(device)
            for input_nodes, output_nodes, blocks in dataloader:
                h = layer(blocks[0], feat[input_nodes])  # len(blocks) = 1
                if l == penultimate:
                    # Capture the raw layer output, before the post-processing below.
                    hid_embed[output_nodes[0]:output_nodes[-1]+1] = h.to('cpu')
                if l != num_layers - 1:
                    h = self._hidden_post(l, h)
                # by design, our output nodes are contiguous
                y[output_nodes[0]:output_nodes[-1]+1] = h.to('cpu')
            feat = y
        y = y.log_softmax(dim=1).cpu()
        return hid_embed, y

class MLP(nn.Module):
    """MLP model: a stack of ``nn.Linear`` layers.

    Every layer except the last is followed by a normalisation module
    (``BatchNorm1d`` or ``Identity``), ReLU and dropout.

    Args:
        in_size: Input feature size of the first linear layer.
        list_hid_size: Output sizes of the hidden layers, in order. ``None``
            or an empty sequence yields a single ``in_size -> out_size`` layer.
        out_size: Output size of the final linear layer.
        dropout: Probability passed to ``nn.Dropout``.
        dtype: Forwarded to every ``nn.Linear`` and ``BatchNorm1d``.
        is_batch_norm: If true, hidden layers use ``BatchNorm1d``; otherwise
            ``Identity`` is used in its place.
        bias: Forwarded to every ``nn.Linear``.
    """
    def __init__(self, in_size,
                 list_hid_size,
                 out_size,
                 dropout=0.0,
                 dtype=None,
                 is_batch_norm=True,
                 bias=True):
        super().__init__()
        self.layers = nn.ModuleList()
        self.bn = nn.ModuleList()

        hidden = [] if list_hid_size is None else list(list_hid_size)
        # Feature widths at every layer boundary: in -> hidden... -> out.
        widths = [in_size, *hidden, out_size]
        final_idx = len(widths) - 2
        for idx, (d_in, d_out) in enumerate(zip(widths[:-1], widths[1:])):
            self.layers.append(nn.Linear(d_in, d_out, dtype=dtype, bias=bias))
            if idx == final_idx:
                continue  # the output layer has no normalisation
            # The norm must match the width this layer *outputs*. The previous
            # code used the preceding hidden width, which broke forward()
            # whenever consecutive hidden sizes differed.
            if is_batch_norm:
                self.bn.append(torch.nn.BatchNorm1d(d_out, dtype=dtype))
            else:
                self.bn.append(torch.nn.Identity())

        self.dropout = nn.Dropout(dropout)
        self.in_size = in_size
        self.list_hid_size = list_hid_size
        self.out_size = out_size
        self.dtype = dtype
        self.is_batch_norm = is_batch_norm
        self.bias = bias

    def forward(self, x):
        """Apply all layers to ``x`` and return the final linear output.

        Hidden layers are each followed by normalisation, ReLU and dropout;
        the last layer's output is returned as is.
        """
        # self.bn holds exactly one module per hidden layer, so zip pairs them 1:1.
        for layer, norm in zip(self.layers[:-1], self.bn):
            x = layer(x)
            x = norm(x)
            x = F.relu(x)
            x = self.dropout(x)
        return self.layers[-1](x)
